Attention
多头缩放点积注意力机制(Scaled Dot-Product Attention)
\[\text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\]
- 输入:
Q - 查询矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。
K - 键矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。
V - 值矩阵地址(行优先),形状 \([B, H, L, D]\) 展平。
params - 其余参数打包成数组。
core_mask(可选) - 核掩码(仅适用于共享存储版本)。
- 输出:
output - 输出地址(行优先),形状 \([B, H, L, D]\) 展平。
- 支持平台:
FT78NEMT7004
备注
当前实现基于 fp32;输入/中间/输出缓冲区不应重叠。
内存布局为行优先(row-major)。
参数数组结构:
int params[10];
params[0] = batch_size;  // 批大小 B
params[1] = seq_len;     // 序列长度 L
params[2] = head_num;    // 多头数量 H
params[3] = head_dim;    // 每头通道维数 D
params[4] = QK地址的低32位;
params[5] = QK地址的高32位;
params[6] = 中间缓冲区地址的低32位;
params[7] = 中间缓冲区地址的高32位;
注意:params[1..3] 的顺序为 L、H、D,与张量形状 \([B, H, L, D]\) 的维度顺序不同。
共享存储版本:
-
void fp_attention_s(float *Q, float *K, float *V, float *output, int *params, int core_mask)
C调用示例:
#include <stdint.h>
#include <stdio.h>

/* Prototype as listed in this document for the shared-memory variant.
 * NOTE(review): normally this would come from the library's own header,
 * whose name is not shown here; drop this line if that header is included. */
void fp_attention_s(float *Q, float *K, float *V, float *output, int *params, int core_mask);

/*
 * Example: invoke the shared-memory attention kernel on buffers placed at
 * fixed DDR addresses.
 *
 * params layout used below:
 *   [0] batch size, [1] sequence length, [2] head count, [3] head dimension,
 *   [4]/[5] low/high 32 bits of the QK buffer address,
 *   [6]/[7] low/high 32 bits of the intermediate (SM) buffer address.
 */
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    int B = 2, L = 128, H = 8, D = 64;

    float *Q  = (float *)0xA0000000; // DDR
    float *K  = (float *)0xA1000000; // DDR
    float *V  = (float *)0xA2000000; // DDR
    float *O  = (float *)0xA3000000; // DDR
    float *QK = (float *)0xA4000000; // DDR
    float *SM = (float *)0xA5000000; // DDR

    /* NOTE(review): 0xff presumably selects cores 0-7 - confirm against the
     * platform documentation. */
    int core_mask = 0xff;

    /* Entries 8 and 9 are not described by this document; zero the whole
     * array so no indeterminate values are handed to the kernel. */
    int params[10] = {0};
    params[0] = B;
    params[1] = L;
    params[2] = H;
    params[3] = D;

    /* Widen to 64 bits BEFORE shifting. A cast binds tighter than >>, so
     * "(int)(uint32_t)(uintptr_t)p >> 32" would shift an already-truncated
     * 32-bit int by its full width, which is undefined behaviour. */
    const uint64_t qk_addr = (uint64_t)(uintptr_t)QK;
    const uint64_t sm_addr = (uint64_t)(uintptr_t)SM;

    params[4] = (int)(uint32_t)(qk_addr & 0xFFFFFFFFu); /* QK address, low word  */
    params[5] = (int)(uint32_t)(qk_addr >> 32);         /* QK address, high word */
    params[6] = (int)(uint32_t)(sm_addr & 0xFFFFFFFFu); /* SM address, low word  */
    params[7] = (int)(uint32_t)(sm_addr >> 32);         /* SM address, high word */

    fp_attention_s(Q, K, V, O, params, core_mask);
    return 0;
}
私有存储版本:
-
void fp_attention_p(float *Q, float *K, float *V, float *output, int *params)
C调用示例:
#include <stdint.h>
#include <stdio.h>

/* Prototype as listed in this document for the private-memory variant.
 * NOTE(review): normally this would come from the library's own header,
 * whose name is not shown here; drop this line if that header is included. */
void fp_attention_p(float *Q, float *K, float *V, float *output, int *params);

/*
 * Example: invoke the private-memory attention kernel on buffers placed at
 * fixed L2 addresses.
 *
 * params layout used below:
 *   [0] batch size, [1] sequence length, [2] head count, [3] head dimension,
 *   [4]/[5] low/high 32 bits of the QK buffer address,
 *   [6]/[7] low/high 32 bits of the intermediate (SM) buffer address.
 */
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    int B = 1, L = 64, H = 4, D = 32;

    float *Q  = (float *)0x10000000; // L2
    float *K  = (float *)0x10040000; // L2
    float *V  = (float *)0x10080000; // L2
    float *O  = (float *)0x100C0000; // L2
    float *QK = (float *)0x10100000; // L2
    float *SM = (float *)0x10200000; // L2

    /* Entries 8 and 9 are not described by this document; zero the whole
     * array so no indeterminate values are handed to the kernel. */
    int params[10] = {0};
    params[0] = B;
    params[1] = L;
    params[2] = H;
    params[3] = D;

    /* Widen to 64 bits BEFORE shifting. A cast binds tighter than >>, so
     * "(int)(uint32_t)(uintptr_t)p >> 32" would shift an already-truncated
     * 32-bit int by its full width, which is undefined behaviour. */
    const uint64_t qk_addr = (uint64_t)(uintptr_t)QK;
    const uint64_t sm_addr = (uint64_t)(uintptr_t)SM;

    params[4] = (int)(uint32_t)(qk_addr & 0xFFFFFFFFu); /* QK address, low word  */
    params[5] = (int)(uint32_t)(qk_addr >> 32);         /* QK address, high word */
    params[6] = (int)(uint32_t)(sm_addr & 0xFFFFFFFFu); /* SM address, low word  */
    params[7] = (int)(uint32_t)(sm_addr >> 32);         /* SM address, high word */

    fp_attention_p(Q, K, V, O, params);
    return 0;
}